#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

This module implements the DLG (Deep Leakage from Gradients) attack.
Please refer to the README for more details.

"""

import os
import time

import numpy
import paddle.fluid as fluid
from PIL import Image

__all__ = ["dlg_attack"]


def dlg_attack(args, feature, label, network, exe, origin_grad):
    """
    Run the DLG (Deep Leakage from Gradients) attack.

    Starting from randomly initialized dummy data, iteratively update the
    dummy feature by gradient descent so that the gradients it induces on
    ``network`` match ``origin_grad`` (the gradients produced by the target
    data). Intermediate reconstructions are periodically saved as images
    under ``args.result_dir``.

    :param args: the parameters for the dlg attack; this function uses
        ``args.iterations``, ``args.learning_rate`` and ``args.result_dir``.
    :param feature: the variable of feature (used only for its shape/dtype).
    :param label: the variable of label (used only for its shape/dtype).
    :param network: the network of the model which is trained; called as
        ``network(x, y)`` and expected to return ``(_, loss)``.
    :param exe: the same executor as used in the normal training procedure.
    :param origin_grad: the original gradients of model params generated by
        the target data, i.e., the data which is being attacked.
    :return: does not return normally; calls ``exit`` once the attack
        finishes (preserved from the original design).
    """
    main_program = fluid.Program()
    # build the attack graph in a fresh program, separate from training
    with fluid.program_guard(main_program):
        # the dummy feature, which aims to imitate the target data
        dummy_x = fluid.data(name="dummy_x",
                             shape=list(feature.shape),
                             dtype=feature.dtype)
        # allow gradients to flow back to dummy_x so it can be updated
        dummy_x.stop_gradient = False

        # the dummy label
        dummy_y = fluid.data(name="dummy_y",
                             shape=list(label.shape),
                             dtype=label.dtype)
        # allow gradients to flow back to dummy_y as well
        dummy_y.stop_gradient = False

        # reuse the model network of normal training
        _, dummy_loss = network(dummy_x, dummy_y)

        # gradients of all trainable params w.r.t. the dummy loss
        all_params = main_program.global_block().all_parameters()
        grad_params = [param for param in all_params if param.trainable]
        dummy_grads = fluid.gradients(dummy_loss, grad_params)

        # placeholders to feed the original (target-data) gradients
        origin_grad_vars = []
        for g_id, origin_g in enumerate(origin_grad):
            grad = fluid.data(name="origin_g_" + str(g_id),
                              shape=origin_g.shape,
                              dtype=origin_g.dtype)
            origin_grad_vars.append(grad)

        # the attack objective: the difference between gradients of model
        # parameters generated respectively by target data and dummy data
        diff_loss = 0.0
        for orig_g, dum_g in zip(origin_grad_vars, dummy_grads):
            cur_loss = fluid.layers.square_error_cost(orig_g, dum_g)
            cur_loss = fluid.layers.reduce_mean(cur_loss)
            diff_loss += cur_loss

        mean_diff_loss = fluid.layers.mean(diff_loss)

        # the gradient of the objective w.r.t. dummy_x
        grad_of_x = fluid.gradients(mean_diff_loss, dummy_x)

    # replace dynamic (-1) dims with 1 so concrete arrays can be built
    dummy_feature_shape = [1 if d == -1 else d for d in list(feature.shape)]
    dummy_label_shape = [1 if d == -1 else d for d in list(label.shape)]

    # Generate dummy target data. The main two types, i.e., float32 and int64,
    # are used here for feature and label variables respectively, which can be
    # changed according to different types in different scenarios.
    dummy_feature = numpy.random.normal(0, 1,
                                        size=dummy_feature_shape
                                        ).astype("float32")
    dummy_label = numpy.zeros(shape=dummy_label_shape).astype("int64")

    feed_dict = {}
    # add original gradients into feed_dict; they stay fixed across iterations
    for idx, orig_g in enumerate(origin_grad):
        feed_dict["origin_g_" + str(idx)] = orig_g

    # the shape of the target image (loop-invariant, hoisted out of the loop)
    # NOTE(review): assumes the feature's last two dims are (H, W) of a
    # single-channel image -- confirm against the caller's data layout
    img_shape = dummy_feature_shape[-2:]

    # the time of starting the attack
    start = time.time()

    for iteration in range(args.iterations):
        feed_dict["dummy_x"] = dummy_feature
        feed_dict["dummy_y"] = dummy_label

        result = exe.run(main_program,
                         feed=feed_dict,
                         fetch_list=[mean_diff_loss] + grad_of_x)
        grad_diff_loss, feature_grad = result[0][0], result[1:]

        # Gradient-descent step on dummy_x: move AGAINST the gradient of the
        # gradient-matching loss. (The previous version added the gradient,
        # which ascends the loss and prevents the attack from converging.)
        feature_grad = numpy.array(feature_grad).reshape(dummy_feature_shape)
        dummy_feature = dummy_feature - args.learning_rate * feature_grad

        # save attack results per 100 iterations
        if iteration % 100 == 0:
            print("Attack Iteration {}: grad_diff_loss = {}"
                  .format(iteration, grad_diff_loss))
            if not os.path.exists(args.result_dir):
                os.makedirs(args.result_dir)
            img = Image.fromarray((dummy_feature * 255)
                                  .reshape(img_shape)
                                  .astype(numpy.uint8))
            img.save(os.path.join(args.result_dir,
                                  "result_{}.png".format(iteration)))

    end = time.time()
    print("Attack cost time in seconds: {}".format(end - start))
    # exit after attack finished
    exit("Attack finished.")
